package com.fengwk.deeplearning.core;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.ops.transforms.Transforms;

/**
 * Logistic sigmoid activation unit.
 *
 * <p>The forward pass computes {@code Transforms.sigmoid(Z)} and then clamps every element into
 * [{@value #MIN_OUTPUT}, {@value #MAX_OUTPUT}], so no activation is ever exactly 0 or 1
 * (presumably to keep downstream terms such as {@code log(A)} / {@code log(1 - A)} finite —
 * NOTE(review): confirm against the loss implementation). Clamping the whole range, rather
 * than only exact 0/1, keeps the activation monotonic in {@code Z}.
 *
 * <p>The clamped activations of the most recent {@link #activate(INDArray)} call are cached and
 * reused by {@link #dActivate(INDArray)}. Because of that per-instance mutable cache, an instance
 * is not safe to share between threads.
 */
class SigmoidUnitCompute implements UnitCompute {

	private static final long serialVersionUID = -8818410435252408169L;

	/** Lower bound applied to every sigmoid output. */
	private static final double MIN_OUTPUT = 0.0001;

	/** Upper bound applied to every sigmoid output. */
	private static final double MAX_OUTPUT = 0.9999;

	/**
	 * Cache of the activations produced by the last {@link #activate(INDArray)} call. It is
	 * per-pass scratch state rather than model state, so it is excluded from serialization;
	 * {@link #dActivate(INDArray)} recomputes it from {@code Z} when it is absent.
	 */
	private transient INDArray A;

	/**
	 * Computes the clamped sigmoid of {@code Z} and caches it for {@link #dActivate(INDArray)}.
	 *
	 * @param Z pre-activation values
	 * @return a new array holding the sigmoid of {@code Z}, clamped into
	 *         [{@value #MIN_OUTPUT}, {@value #MAX_OUTPUT}]; the same instance is kept as the cache,
	 *         so mutating it also affects the next {@link #dActivate(INDArray)} call
	 */
	@Override
	public INDArray activate(INDArray Z) {
		A = clampedSigmoid(Z);
		return A;
	}

	/**
	 * Returns the sigmoid derivative {@code A * (1 - A)} as a new array.
	 *
	 * <p>{@code A} is the activation cached by the most recent {@link #activate(INDArray)} call;
	 * callers are expected to pass the same {@code Z} they passed there. {@code Z} is only used
	 * when no cache exists (for example before any forward pass, or after deserialization), in
	 * which case the clamped activation is recomputed from it without updating the cache.
	 *
	 * @param Z pre-activation values (used only when no cached activation is available)
	 * @return the element-wise derivative of the clamped sigmoid
	 */
	@Override
	public INDArray dActivate(INDArray Z) {
		INDArray a = (A != null) ? A : clampedSigmoid(Z);
		return a.mul(a.rsub(1));
	}

	/**
	 * Computes {@code Transforms.sigmoid(Z)} and clamps each element into
	 * [{@value #MIN_OUTPUT}, {@value #MAX_OUTPUT}] in place on the freshly computed array.
	 */
	private static INDArray clampedSigmoid(INDArray Z) {
		INDArray out = Transforms.sigmoid(Z);
		long n = out.length();
		for (int i = 0; i < n; i++) {
			double v = out.getDouble(i);
			// Clamp the whole out-of-range region, not only exact 0/1, so the mapping stays
			// monotonic: a saturated 1.0 must not end up below a merely-near-1 value.
			if (v < MIN_OUTPUT) {
				out.putScalar(i, MIN_OUTPUT);
			} else if (v > MAX_OUTPUT) {
				out.putScalar(i, MAX_OUTPUT);
			}
		}
		return out;
	}

}
